"""
Phase 2b: Extended Decomposition & Debt Simulations
=====================================================
Four additional analyses for the paper:

1. Health spending decomposition (Z -> health, Z -> non-health expenditure)
2. Structural Bohn stress test (same sample comparison)
3. Forward debt dynamics simulation (2024-2060)
4. Pension reform interaction (does reform weaken the spending channel?)

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
        multilateral/followup/data/processed/full_panel.csv (for health_exp_gdp and future demographics)
Output: fiscal_dominance/output/tables/phase2b_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Absolute project root. The path is hard-coded, so the script only runs where
# the project lives at exactly this location.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
FD_DIR = PROJECT_DIR / "fiscal_dominance"
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
# Folder main() reads fiscal_panel.csv from.
PROCESSED_DIR = FD_DIR / "data" / "processed"
# Folder main() writes the phase2b_*.csv tables to.
TABLE_DIR = FD_DIR / "output" / "tables"

# Put multilateral/src at the front of sys.path so it is searched first when
# importing the `model` module below.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS model, print its diagnostics, and tabulate it.

    Args:
        y: Dependent-variable values passed straight to ``PanelGLS.fit``.
        X: Regressor matrix passed straight to ``PanelGLS.fit``.
        entity_ids: Panel entity identifier per observation.
        time_ids: Time identifier per observation.
        feature_names: Names for the columns of ``X``, forwarded to the
            model's ``summary`` and ``to_dataframe`` methods.
        label: Human-readable model name; printed in the banner and stored
            in the ``model`` column of the returned table.

    Returns:
        ``(model, result_df)``: the fitted model and its coefficient table
        with ``model``, ``n_obs``, ``n_countries``, ``r_squared`` and ``rho``
        columns appended.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Banner with the headline fit statistics, followed by the model's own
    # coefficient summary.
    print(f"\n{'=' * 70}")
    print(f"  {label}")
    print(f"  N={estimator.n_obs:,}, {estimator.n_countries} countries, "
          f"R²={estimator.r_squared:.4f}, rho={estimator.rho:.3f}")
    print(f"{'=' * 70}")
    estimator.summary(feature_names=feature_names)

    # Tag every coefficient row with the model label and fit statistics so
    # tables from different models can be concatenated later.
    table = estimator.to_dataframe(feature_names=feature_names)
    table['model'] = label
    for stat_name in ('n_obs', 'n_countries', 'r_squared', 'rho'):
        table[stat_name] = getattr(estimator, stat_name)
    return estimator, table


def main():
    """Run the four Phase 2b analyses, print the results, and save the tables.

    The steps run in order on the fiscal panel (with ``health_exp_gdp``
    merged in from the multilateral full panel):

    1. Health decomposition: health, non-health and total expenditure plus
       revenue regressed on ``Z_1..Z_3`` and, separately, on ``old_dep`` and
       its square, restricted to rows with ``health_exp_gdp``.
    2. Structural Bohn stress test: primary-balance and structural-balance
       regressions on the countries that have structural balance data.
    3. Forward debt simulation for a fixed list of countries over 2024-2060,
       driven by fitted fiscal-gap and r-g models.
    4. Pension reform interactions with the OADR effect and the Bohn
       coefficient.

    Side effects: reads ``fiscal_panel.csv`` and ``full_panel.csv``, prints
    progress and a master summary, and writes
    ``phase2b_health_decomp_oadr.csv``, ``phase2b_debt_simulation.csv``,
    ``phase2b_debt_trajectories.csv`` and
    ``phase2b_decomposition_results.csv`` into ``TABLE_DIR`` (created if
    missing).

    Returns:
        DataFrame concatenating the coefficient tables of every model fitted
        through ``fit_and_report``, or an empty DataFrame if none was fitted.
    """
    print("=" * 70)
    print("PHASE 2b: Extended Decomposition & Debt Simulations")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Loaded fiscal panel: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Read the multilateral panel once: it supplies health_exp_gdp for the
    # merge below and the future demographics for the Part 3 simulation.
    fp_full = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")

    # to_csv does not create missing directories, so make sure the output
    # folder exists before any table is written.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    # Merge health_exp_gdp from full_panel
    fp = fp_full[['iso3', 'year', 'health_exp_gdp']]
    df = df.merge(fp, on=['iso3', 'year'], how='left')
    print(f"health_exp_gdp after merge: {df['health_exp_gdp'].notna().sum():,} obs, "
          f"{df.loc[df['health_exp_gdp'].notna(), 'iso3'].nunique()} countries")

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    decomp_controls = ['debt_lag', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    # Keep only the controls that are actually present in the panel
    decomp_controls = [c for c in decomp_controls if c in df.columns]

    # =================================================================
    # 1. HEALTH SPENDING DECOMPOSITION
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  PART 1: Health vs Non-Health Expenditure Decomposition")
    print("=" * 70)

    # Compute non-health expenditure
    df['non_health_exp_gdp'] = df['govt_expenditure_gdp'] - df['health_exp_gdp']
    print(f"  non_health_exp_gdp: {df['non_health_exp_gdp'].notna().sum():,} obs")

    # 1a. Z -> health expenditure
    health_vars = demo_vars + decomp_controls
    for dep_label, dep_col in [
        ('Health Expenditure/GDP', 'health_exp_gdp'),
        ('Non-Health Expenditure/GDP', 'non_health_exp_gdp'),
        ('Total Expenditure/GDP (health sample)', 'govt_expenditure_gdp'),
        ('Revenue/GDP (health sample)', 'govt_revenue_gdp'),
    ]:
        # Use the intersection sample (countries with health data): rows are
        # dropped on health_exp_gdp even when it is not the dependent variable
        est = df.dropna(subset=[dep_col, 'health_exp_gdp'] + health_vars).copy()

        if len(est) < 100:
            print(f"\n  Skipping {dep_label}: insufficient obs ({len(est)})")
            continue

        m, r = fit_and_report(
            est[dep_col].values, est[health_vars].values,
            est['iso3'].values, est['year'].values,
            health_vars, f"Health Decomp: Z -> {dep_label}"
        )
        all_results.append(r)

    # 1b. OADR version (more intuitive coefficients)
    df['old_dep_sq'] = df['old_dep'] ** 2
    oadr_controls = ['old_dep', 'old_dep_sq'] + decomp_controls

    print(f"\n{'=' * 70}")
    print("  OADR Specification (Health Decomposition)")
    print("=" * 70)

    oadr_summary = []
    for dep_label, dep_col in [
        ('Health Expenditure/GDP', 'health_exp_gdp'),
        ('Non-Health Expenditure/GDP', 'non_health_exp_gdp'),
        ('Total Expenditure/GDP', 'govt_expenditure_gdp'),
        ('Revenue/GDP', 'govt_revenue_gdp'),
    ]:
        est = df.dropna(subset=[dep_col, 'health_exp_gdp'] + oadr_controls).copy()
        if len(est) < 100:
            continue

        m, r = fit_and_report(
            est[dep_col].values, est[oadr_controls].values,
            est['iso3'].values, est['year'].values,
            oadr_controls, f"Health Decomp (OADR): {dep_label}"
        )
        all_results.append(r)

        # Only the linear OADR term goes into the compact summary table
        idx_oadr = oadr_controls.index('old_dep')
        oadr_summary.append({
            'Dependent Variable': dep_label,
            'OADR coef': m.beta[idx_oadr],
            'OADR se': m.se[idx_oadr],
            'OADR p': m.pvalues[idx_oadr],
            'N': m.n_obs,
            'R²': m.r_squared,
        })

    if oadr_summary:
        oadr_df = pd.DataFrame(oadr_summary)
        print(f"\n{'=' * 70}")
        print("  OADR EFFECT ON SPENDING COMPONENTS (linear term)")
        print("=" * 70)
        print(oadr_df.to_string(index=False, float_format='%.4f'))
        oadr_df.to_csv(TABLE_DIR / "phase2b_health_decomp_oadr.csv", index=False)

    # =================================================================
    # 2. STRUCTURAL BOHN STRESS TEST
    #    Run primary balance Bohn on the SAME 83-country sample
    #    that has structural balance data
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  PART 2: Structural Bohn Stress Test (Same-Sample Comparison)")
    print("=" * 70)

    bohn_controls = ['output_gap_hp', 'govt_exp_gap']
    bohn_controls = [c for c in bohn_controls if c in df.columns and df[c].notna().sum() > 100]

    # Countries with structural balance data
    struct_countries = df.loc[df['structural_bal_gdp'].notna(), 'iso3'].unique()
    print(f"  Countries with structural balance: {len(struct_countries)}")

    df_struct = df[df['iso3'].isin(struct_countries)].copy()

    # Bohn interaction vars
    df_struct['debt_x_Z1'] = df_struct['debt_lag'] * df_struct['Z_1']
    df_struct['debt_x_Z2'] = df_struct['debt_lag'] * df_struct['Z_2']
    df_struct['debt_x_Z3'] = df_struct['debt_lag'] * df_struct['Z_3']
    interaction_vars = ['debt_x_Z1', 'debt_x_Z2', 'debt_x_Z3']

    bohn_vars = ['debt_lag'] + demo_vars + interaction_vars + bohn_controls

    # NOTE: "same sample" means the same set of countries. Each regression
    # below drops rows on its own dependent variable, so the two fits can
    # still differ in which country-years they use.

    # 2a. Primary balance on structural-balance sample
    est_pb = df_struct.dropna(subset=['primary_bal_gdp'] + bohn_vars).copy()
    if len(est_pb) >= 100:
        m_pb, r_pb = fit_and_report(
            est_pb['primary_bal_gdp'].values, est_pb[bohn_vars].values,
            est_pb['iso3'].values, est_pb['year'].values,
            bohn_vars, "Stress Test: Primary Bal Bohn (structural sample)"
        )
        all_results.append(r_pb)

    # 2b. Structural balance on same sample
    est_sb = df_struct.dropna(subset=['structural_bal_gdp'] + bohn_vars).copy()
    if len(est_sb) >= 100:
        m_sb, r_sb = fit_and_report(
            est_sb['structural_bal_gdp'].values, est_sb[bohn_vars].values,
            est_sb['iso3'].values, est_sb['year'].values,
            bohn_vars, "Stress Test: Structural Bal Bohn (same sample)"
        )
        all_results.append(r_sb)

    # Compare (m_pb / m_sb only exist when both samples passed the size check)
    if len(est_pb) >= 100 and len(est_sb) >= 100:
        print(f"\n{'=' * 70}")
        print("  SAME-SAMPLE COMPARISON")
        print("=" * 70)
        idx = bohn_vars.index('debt_lag')
        print(f"  Primary balance Bohn:    beta={m_pb.beta[idx]:.4f} (p={m_pb.pvalues[idx]:.4f}), "
              f"N={m_pb.n_obs}, {m_pb.n_countries} countries")
        print(f"  Structural balance Bohn: beta={m_sb.beta[idx]:.4f} (p={m_sb.pvalues[idx]:.4f}), "
              f"N={m_sb.n_obs}, {m_sb.n_countries} countries")
        diff = m_sb.beta[idx] - m_pb.beta[idx]
        print(f"  Difference (structural - primary): {diff:.4f}")
        print(f"  Interpretation: {'Automatic stabilizers mask fiscal loosening' if diff < 0 else 'Results consistent across measures'}")

    # =================================================================
    # 3. FORWARD DEBT DYNAMICS SIMULATION
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  PART 3: Forward Debt Dynamics Simulation (2024-2060)")
    print("=" * 70)

    # Estimate the expenditure-revenue gap as a function of OADR
    # From phase2: expenditure OADR coef ~120, revenue OADR coef ~48
    # Gap = expenditure - revenue, so gap OADR coef ~72
    # But let's estimate it directly
    df['fiscal_gap'] = df['govt_expenditure_gdp'] - df['govt_revenue_gdp']
    gap_vars = ['old_dep', 'old_dep_sq'] + [c for c in decomp_controls if c != 'debt_lag']
    est_gap = df.dropna(subset=['fiscal_gap'] + gap_vars).copy()

    m_gap = None
    if len(est_gap) >= 200:
        m_gap, r_gap = fit_and_report(
            est_gap['fiscal_gap'].values, est_gap[gap_vars].values,
            est_gap['iso3'].values, est_gap['year'].values,
            gap_vars, "Gap Model: Z -> (Expenditure - Revenue)"
        )
        all_results.append(r_gap)

    # Also estimate r-g model for simulation (fitted directly, so it is not
    # added to all_results)
    rg_vars = demo_vars + [c for c in decomp_controls if c != 'debt_lag']
    est_rg = df.dropna(subset=['r_minus_g'] + rg_vars).copy()
    m_rg = None
    if len(est_rg) >= 200:
        m_rg = PanelGLS()
        m_rg.fit(est_rg['r_minus_g'].values, est_rg[rg_vars].values,
                 est_rg['iso3'].values, est_rg['year'].values)
        print(f"\n  r-g model for simulation: N={m_rg.n_obs}, R²={m_rg.r_squared:.4f}")

    # Future demographics come from fp_full, loaded once at the top

    # Simulation parameters
    sim_countries = ['JPN', 'ITA', 'USA', 'FRA', 'DEU', 'GBR', 'ESP',
                     'KOR', 'CHN', 'BRA', 'IND', 'POL', 'THA', 'ZAF', 'MEX']
    sim_years = list(range(2024, 2061))
    control_means = {c: df[c].mean() for c in decomp_controls if c in df.columns}

    sim_rows = []
    for iso3 in sim_countries:
        # Starting conditions (latest observed). The simulation always starts
        # in 2024 from this debt level, whichever year it was observed in.
        latest = df[(df['iso3'] == iso3) & df['govt_debt_gdp'].notna()].sort_values('year').tail(1)
        if len(latest) == 0:
            continue

        debt_t = latest['govt_debt_gdp'].values[0]

        # Country-specific controls (held constant); fall back to the panel
        # mean when the latest observation is missing for a control
        ctrl = {}
        for c in decomp_controls:
            if c == 'debt_lag':
                continue
            val = latest[c].values[0] if c in latest.columns and latest[c].notna().values[0] else control_means.get(c, 0)
            ctrl[c] = val

        # Filter the demographic panel to this country once instead of
        # rescanning the whole frame for every simulated year.
        demo_iso = fp_full[fp_full['iso3'] == iso3]

        for year in sim_years:
            # Get future demographics
            yr_demo = demo_iso[demo_iso['year'] == year]
            if len(yr_demo) == 0:
                continue

            oadr = yr_demo['old_dep'].values[0] if 'old_dep' in yr_demo.columns and yr_demo['old_dep'].notna().values[0] else np.nan
            z_vals = {}
            for zv in demo_vars:
                if zv in yr_demo.columns and yr_demo[zv].notna().values[0]:
                    z_vals[zv] = yr_demo[zv].values[0]

            # Skip years where OADR or any of the three Z terms is missing
            if np.isnan(oadr) or len(z_vals) < 3:
                continue

            # Project fiscal gap from OADR
            gap_pred = np.nan
            if m_gap is not None:
                x_gap = []
                for v in gap_vars:
                    if v == 'old_dep':
                        x_gap.append(oadr)
                    elif v == 'old_dep_sq':
                        x_gap.append(oadr ** 2)
                    else:
                        x_gap.append(ctrl.get(v, 0))
                if not any(np.isnan(x) for x in x_gap):
                    gap_pred = m_gap.constant + np.dot(m_gap.beta, x_gap)

            # Project r-g from Z polynomials
            rg_pred = np.nan
            if m_rg is not None:
                x_rg = []
                for v in rg_vars:
                    if v in z_vals:
                        x_rg.append(z_vals[v])
                    else:
                        x_rg.append(ctrl.get(v, 0))
                if not any(np.isnan(x) for x in x_rg):
                    rg_pred = m_rg.constant + np.dot(m_rg.beta, x_rg)

            # Debt dynamics: debt(t+1) = debt(t) * (1 + r-g/100) + fiscal_gap
            # Simplified: debt(t+1) = debt(t) + (r-g)/100 * debt(t) + primary_deficit
            # where primary_deficit ≈ fiscal_gap (expenditure - revenue)
            # NOTE(review): if govt_expenditure_gdp includes interest payments,
            # fiscal_gap is the overall rather than the primary deficit and the
            # r-g term would count interest twice — confirm the definition.
            if not np.isnan(gap_pred) and not np.isnan(rg_pred):
                interest_accumulation = (rg_pred / 100) * debt_t
                debt_t_new = debt_t + interest_accumulation + gap_pred
            elif not np.isnan(gap_pred):
                debt_t_new = debt_t + gap_pred
            else:
                # No usable projection for this year: carry debt forward
                debt_t_new = debt_t

            # Floor at 0 before recording, so the saved path is the same
            # state that is carried into the next year
            debt_t_new = max(debt_t_new, 0)

            sim_rows.append({
                'iso3': iso3,
                'year': year,
                'old_dep': oadr,
                'proj_debt_gdp': debt_t_new,
                'proj_r_minus_g': rg_pred,
                'proj_fiscal_gap': gap_pred,
                'debt_from_rg': (rg_pred / 100) * debt_t if not np.isnan(rg_pred) else np.nan,
                'debt_from_gap': gap_pred,
            })

            debt_t = debt_t_new

    sim_df = pd.DataFrame(sim_rows)
    if len(sim_df) > 0:
        sim_df.to_csv(TABLE_DIR / "phase2b_debt_simulation.csv", index=False)
        print(f"\n  Saved debt simulation: {len(sim_df)} rows")

        # Print key years
        for show_year in [2030, 2040, 2050, 2060]:
            yr = sim_df[sim_df['year'] == show_year].sort_values('proj_debt_gdp', ascending=False)
            if len(yr) == 0:
                continue
            print(f"\n  Projected Debt/GDP in {show_year}:")
            print(yr[['iso3', 'old_dep', 'proj_debt_gdp', 'proj_fiscal_gap', 'proj_r_minus_g']]
                  .head(15).to_string(index=False, float_format='%.1f'))

        # Trajectory for key countries
        print(f"\n{'=' * 70}")
        print("  Debt Trajectories (selected countries)")
        print("=" * 70)
        milestone_years = [2024, 2030, 2035, 2040, 2045, 2050, 2055, 2060]
        trajectory = sim_df[sim_df['year'].isin(milestone_years)].pivot(
            index='iso3', columns='year', values='proj_debt_gdp'
        )
        if len(trajectory) > 0:
            print(trajectory.to_string(float_format='%.0f'))
            trajectory.to_csv(TABLE_DIR / "phase2b_debt_trajectories.csv")

    # =================================================================
    # 4. PENSION REFORM INTERACTION
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  PART 4: Pension Reform Interaction")
    print("=" * 70)

    # Major pension reform countries and approximate reform years
    # (parametric or systemic reforms that changed benefit formulas)
    reform_events = {
        'SWE': 1998,  # NDC system
        'POL': 1999,  # NDC + funded pillar
        'LVA': 1996,  # NDC system
        'ITA': 1995,  # Dini reform (NDC)
        'DEU': 2001,  # Riester reform
        'FRA': 2003,  # Fillon reform
        'HUN': 1998,  # Mixed system
        'CHL': 1981,  # Fully funded (pre-sample, always reformed)
        'AUS': 1992,  # Superannuation
        'JPN': 2004,  # Macroeconomic indexing
        'GBR': 2007,  # Turner reforms
        'KOR': 2007,  # National Pension reform
        'CZE': 2013,  # Parametric
        'ESP': 2011,  # Parametric
        'GRC': 2010,  # Major parametric
        'PRT': 2007,  # Sustainability factor
    }

    # Create reform dummy: 1 if country has reformed AND year >= reform year
    df['pension_reform'] = 0.0
    for iso3, reform_year in reform_events.items():
        mask = (df['iso3'] == iso3) & (df['year'] >= reform_year)
        df.loc[mask, 'pension_reform'] = 1.0

    n_reform = df['pension_reform'].sum()
    n_reform_countries = df.loc[df['pension_reform'] == 1, 'iso3'].nunique()
    print(f"  Reform dummy: {n_reform:.0f} obs across {n_reform_countries} countries")

    # 4a. Does reform weaken the OADR -> expenditure link?
    df['oadr_x_reform'] = df['old_dep'] * df['pension_reform']
    reform_vars = ['old_dep', 'old_dep_sq', 'pension_reform', 'oadr_x_reform'] + \
                  [c for c in decomp_controls if c != 'debt_lag']

    for dep_label, dep_col in [
        ('Expenditure/GDP', 'govt_expenditure_gdp'),
        ('Revenue/GDP', 'govt_revenue_gdp'),
        ('Fiscal Gap', 'fiscal_gap'),
    ]:
        if dep_col not in df.columns:
            continue
        est = df.dropna(subset=[dep_col] + reform_vars).copy()
        if len(est) < 100:
            print(f"  Skipping {dep_label}: insufficient obs")
            continue

        m, r = fit_and_report(
            est[dep_col].values, est[reform_vars].values,
            est['iso3'].values, est['year'].values,
            reform_vars, f"Reform Interaction: OADR -> {dep_label}"
        )
        all_results.append(r)

        # Reformed-country slope = base OADR coefficient + interaction term
        idx_int = reform_vars.index('oadr_x_reform')
        idx_oadr = reform_vars.index('old_dep')
        print(f"\n  >>> OADR effect on {dep_label}:")
        print(f"      Unreformed: {m.beta[idx_oadr]:.2f}")
        print(f"      Reform interaction: {m.beta[idx_int]:.2f} (p={m.pvalues[idx_int]:.4f})")
        print(f"      Reformed total: {m.beta[idx_oadr] + m.beta[idx_int]:.2f}")

    # 4b. Reform effect on Bohn coefficient
    bohn_controls_avail = ['output_gap_hp', 'govt_exp_gap']
    bohn_controls_avail = [c for c in bohn_controls_avail if c in df.columns and df[c].notna().sum() > 100]

    df['debt_x_reform'] = df['debt_lag'] * df['pension_reform']
    bohn_reform_vars = ['debt_lag', 'pension_reform', 'debt_x_reform'] + \
                       demo_vars + bohn_controls_avail

    est_br = df.dropna(subset=['primary_bal_gdp'] + bohn_reform_vars).copy()
    if len(est_br) >= 200:
        m_br, r_br = fit_and_report(
            est_br['primary_bal_gdp'].values, est_br[bohn_reform_vars].values,
            est_br['iso3'].values, est_br['year'].values,
            bohn_reform_vars, "Reform: Bohn with Reform Interaction"
        )
        all_results.append(r_br)

        idx_debt = bohn_reform_vars.index('debt_lag')
        idx_reform_int = bohn_reform_vars.index('debt_x_reform')
        print(f"\n  >>> Bohn coefficient:")
        print(f"      Unreformed: {m_br.beta[idx_debt]:.4f}")
        print(f"      Reform interaction: {m_br.beta[idx_reform_int]:.4f} (p={m_br.pvalues[idx_reform_int]:.4f})")
        print(f"      Reformed total: {m_br.beta[idx_debt] + m_br.beta[idx_reform_int]:.4f}")

    # =================================================================
    # Save all results
    # =================================================================
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(TABLE_DIR / "phase2b_decomposition_results.csv", index=False)
        print(f"\n{'=' * 70}")
        print(f"Saved: {TABLE_DIR / 'phase2b_decomposition_results.csv'}")
        print(f"  {len(results_df)} rows across {results_df['model'].nunique()} models")

    # =================================================================
    # MASTER SUMMARY
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  MASTER SUMMARY — Key Numbers for the Paper")
    print("=" * 70)

    print("\n  1. HEALTH DECOMPOSITION (OADR linear coefficient):")
    if oadr_summary:
        for row in oadr_summary:
            sig = '***' if row['OADR p'] < 0.01 else '**' if row['OADR p'] < 0.05 else '*' if row['OADR p'] < 0.1 else 'ns'
            print(f"     {row['Dependent Variable']:<35} {row['OADR coef']:>8.1f} (p={row['OADR p']:.4f}) {sig}")

    print("\n  2. STRUCTURAL BOHN (same-sample):")
    # Same guard as in Part 2, so m_pb and m_sb are defined here
    if len(est_pb) >= 100 and len(est_sb) >= 100:
        idx = bohn_vars.index('debt_lag')
        print(f"     Primary balance:   beta = {m_pb.beta[idx]:.4f} (p={m_pb.pvalues[idx]:.4f})")
        print(f"     Structural balance: beta = {m_sb.beta[idx]:.4f} (p={m_sb.pvalues[idx]:.4f})")

    print("\n  3. DEBT TRAJECTORIES (projected debt/GDP):")
    if len(sim_df) > 0:
        for iso3 in ['JPN', 'ITA', 'USA', 'KOR', 'CHN']:
            row_2060 = sim_df[(sim_df['iso3'] == iso3) & (sim_df['year'] == 2060)]
            row_2024 = sim_df[(sim_df['iso3'] == iso3) & (sim_df['year'] == 2024)]
            if len(row_2060) > 0 and len(row_2024) > 0:
                d0 = row_2024['proj_debt_gdp'].values[0]
                d1 = row_2060['proj_debt_gdp'].values[0]
                print(f"     {iso3}: {d0:.0f}% -> {d1:.0f}% ({d1-d0:+.0f}pp)")

    print("\n  4. PENSION REFORM:")
    # Same guard as in Part 4b, so m_br is defined here
    if len(est_br) >= 200:
        idx_debt = bohn_reform_vars.index('debt_lag')
        idx_reform_int = bohn_reform_vars.index('debt_x_reform')
        print(f"     Unreformed Bohn beta: {m_br.beta[idx_debt]:.4f}")
        print(f"     Reform interaction:   {m_br.beta[idx_reform_int]:.4f} (p={m_br.pvalues[idx_reform_int]:.4f})")

    return results_df if all_results else pd.DataFrame()


# Run the full pipeline only when executed as a script; the combined results
# table returned by main() is kept in `results`.
if __name__ == "__main__":
    results = main()
